# 导入必要的库
import numpy as np  # 用于处理多维数组和矩阵运算
import faiss  # 用于高效相似性搜索和稠密向量聚类
from util import createXY  # 用于创建数据集的特征和标签
from sklearn.model_selection import train_test_split  # 用于拆分数据集为训练集和测试集
from sklearn.neighbors import KNeighborsClassifier  # sklearn中的K近邻分类器
import argparse  # 用于解析命令行参数
import logging  # 用于记录日志
from tqdm import tqdm  # 用于在循环中显示进度条
from FaissKNeighbors import FaissKNeighbors  # 导入自定义的FaissKNeighbors类

# 配置logging, 确保能够打印正在运行的函数名
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 获取命令行参数
def get_args():
    parser = argparse.ArgumentParser(description='使用CPU或GPU训练模型。') # 创建命令行参数解析器对象
    parser.add_argument('-m', '--mode', type=str, required=True, choices=['cpu', 'gpu'], help='选择训练模式：CPU或GPU。')
    parser.add_argument('-f', '--feature', type=str, required=True, choices=['flat', 'vgg'], help='选择特征提取方法：flat或vgg。')
    parser.add_argument('-l', '--library', type=str, required=True, choices=['sklearn', 'faiss'], help='选择使用的库：sklearn或faiss。')
    args = parser.parse_args()
    return args

# 主函数，运行训练过程
def main():
    args = get_args()
    
    # 根据mode初始化FAISS所需的资源
    if args.mode == 'gpu':
        print("正在初始化FAISS的GPU资源...")
        try:
            res = faiss.StandardGpuResources()
            print("FAISS的GPU资源初始化完成。")
        except Exception as e:
            print(f"无法初始化FAISS的GPU资源: {e}")
            print("切换到CPU模式")
            res = None
            args.mode = 'cpu'
    else:
        res = None
    logging.info(f"选择模式是 {args.mode.upper()}")
    logging.info(f"选择特征提取方法是 {args.feature.upper()}")
    logging.info(f"选择使用的库是 {args.library.upper()}")

    # 载入和预处理数据
    print("正在载入和预处理数据...")
    X, y = createXY(train_folder="/mnt/cgshare/data/train/", dest_folder=".", method=args.feature)
    X = np.array(X).astype('float32')
    print("数据载入完成，正在进行L2归一化...")
    faiss.normalize_L2(X)  # 对数据进行L2归一化
    print("数据L2归一化完成。")
    y = np.array(y)
    logging.info("数据加载和预处理完成。")

    # 数据集分割为训练集和测试集
    print("正在分割数据集为训练集和测试集...")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=2023)
    print("数据集划分完成。")
    logging.info("数据集划分为训练集和测试集。")

    # 初始化变量，跟踪最佳的k值和相应的准确率
    best_k = -1
    best_accuracy = 0.0
    print("初始化变量完成，开始寻找最佳K值...")

    # 寻找最佳K值
    for k in range(1, 21):
        print(f"正在训练K值为{k}的模型...")
        if args.library == 'faiss':
            model = FaissKNeighbors(k=k, gpu=args.mode == 'gpu')
        else:
            model = KNeighborsClassifier(n_neighbors=k)
        
        print(f"正在拟合模型，K值为{k}...")
        model.fit(X_train, y_train)
        print(f"模型拟合完成，K值为{k}。")

        print(f"正在评估模型，K值为{k}...")
        accuracy = model.score(X_test, y_test)
        print(f"K值为{k}的模型准确率为{accuracy:.4f}")
        
        if accuracy > best_accuracy:
            best_k = k
            best_accuracy = accuracy
            print(f"发现新的最佳K值：{best_k}，准确率：{best_accuracy:.4f}")

    print(f"最佳K值为{best_k}，准确率为{best_accuracy:.4f}")
    logging.info(f"最佳K值为{best_k}，准确率为{best_accuracy:.4f}")

if __name__ == "__main__":
    main()